from abc import ABC, abstractmethod
import pandas as pd
from robustx.lib.tasks.Task import Task

from tqdm import tqdm

from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp
import pickle, math, os
from tqdm.auto import tqdm
import pandas as pd
import threading

def _cf_chunk_worker_v2(
        idx_chunk,
        df_blob,
        task_cls,
        task_init_args,
        cfunc_blob,
        gen_cls,
        neg_value,
        column_name,
        eps,
        method_name,
        kwargs,
):
    """
    Generate counterfactuals for one chunk of rows inside a worker.

    The task is rebuilt from ``task_cls(*task_init_args)`` instead of being
    unpickled (the original author notes this avoids file descriptor issues).

    @param idx_chunk: Iterable of positional row indices into the unpickled frame.
    @param df_blob: Pickled DataFrame holding the instances.
    @param task_cls: Callable used to rebuild the task.
    @param task_init_args: Positional arguments passed to ``task_cls``.
    @param cfunc_blob: Pickled custom distance function, or a falsy value for none.
    @param gen_cls: Generator class, called as ``gen_cls(task, custom_distance_func=...)``.
    @param neg_value: Forwarded to ``generate_for_instance``.
    @param column_name: Forwarded to ``generate_for_instance``.
    @param eps: Value assigned to ``task.eps`` (and ``task.delta`` for "RNCE").
    @param method_name: Method name; only "RNCE" triggers extra task setup here.
    @param kwargs: Extra keyword arguments for ``generate_for_instance``. An
        optional ``'task_params'`` entry (a dict) is applied to the task via
        ``setattr`` and is not forwarded. May be None. The dict is not modified.
    @return: List with one result per index, each re-indexed to its positional index.
    """
    import pickle as _p

    # NOTE(review): pickle.loads can execute arbitrary code; both blobs must
    # originate from the trusted parent process, never from external input.
    X_local = _p.loads(df_blob)
    cfunc = _p.loads(cfunc_blob) if cfunc_blob else None

    # Recreate task instead of unpickling it (avoids file descriptor issues)
    task = task_cls(*task_init_args)
    task.eps = eps

    # Method-specific parameters on the task
    if method_name == "RNCE":
        task.delta = eps

    # Work on a private copy: popping from the caller's dict would silently
    # drop 'task_params' for every later chunk that reuses the same dict.
    call_kwargs = dict(kwargs) if kwargs else {}
    task_params = call_kwargs.pop('task_params', None) or {}
    for param_name, param_value in task_params.items():
        setattr(task, param_name, param_value)

    # Fresh generator instance bound to the rebuilt task
    gen = gen_cls(task, custom_distance_func=cfunc)

    out_list = []
    for idx in idx_chunk:
        cf = gen.generate_for_instance(
            X_local.iloc[idx],
            neg_value=neg_value,
            column_name=column_name,
            **call_kwargs)
        # Tag every produced row with the positional index of its source row.
        cf.index = [idx] * len(cf)
        out_list.append(cf)

    return out_list


class CEGenerator(ABC):
    """
    Abstract base class for generating counterfactual explanations for a task.

    Subclasses implement `_generation_method`; this class supplies the
    per-instance, batch and "all negative instances" entry points and stores
    an optional custom distance function for subclasses to use.

    Attributes:
        _task (Task): The task to solve.
        __customFunc (callable, optional): A custom distance function.
    """

    def __init__(self, ct: "Task", custom_distance_func=None, *args, **kwargs):
        """
        Initializes the CEGenerator with a task and an optional custom distance function.

        Extra positional and keyword arguments are accepted and ignored here.

        @param ct: The Task instance to solve.
        @param custom_distance_func: An optional custom distance function.
        """
        self._task = ct
        self.__customFunc = custom_distance_func

    @property
    def task(self):
        """
        Returns the task passed at instantiation.
        """
        return self._task

    def generate(
        self,
        instances: pd.DataFrame,
        *,
        neg_value=0,
        column_name: str = "target",
        n_jobs = 1,
        chunk_size: int = None,
        use_threading: bool = False,
        eps: float = None,
        method_name: str = None,
        **kwargs,
    ) -> pd.DataFrame:
        """
        Generate counterfactuals for every row of `instances`.

        NOTE: parallel execution is currently disabled. The method always runs
        the sequential path, so `n_jobs`, `chunk_size`, `use_threading`, `eps`
        and `method_name` are accepted for interface compatibility but have no
        effect here.

        Parameters
        ----------
        instances : pd.DataFrame
            Rows you want counterfactuals for.
        neg_value : int, default 0
            What counts as a negative label (passed through).
        column_name : str, default "target"
            Name of the target column (passed through).
        n_jobs : int, optional
            Currently ignored (forced to 1).
        chunk_size : int, optional
            Currently ignored.
        use_threading : bool, default False
            Currently ignored.
        eps : float, optional
            Currently ignored by this method.
        method_name : str, optional
            Currently ignored by this method.
        kwargs :
            Any extra arguments forwarded to `generate_for_instance`.

        Returns
        -------
        pd.DataFrame
            The per-row results concatenated in the row order of `instances`,
            or an empty DataFrame when `instances` has no rows.
        """
        # Parallelism is deliberately switched off; the requested value is
        # overridden so the message below reflects what actually happens.
        n_jobs = 1
        print(f"Using {n_jobs} jobs for parallel processing.")

        dfs = [
            self.generate_for_instance(row, neg_value=neg_value,
                                    column_name=column_name, **kwargs)
            for _, row in tqdm(instances.iterrows(), total=len(instances),
                            desc="CF-sequential", disable=False, file=None)
        ]
        return pd.concat(dfs) if dfs else pd.DataFrame()

    def _generate_threaded(self, instances, neg_value, column_name, n_jobs, chunk_size, eps, method_name, **kwargs):
        """
        Alternative implementation using threads instead of processes.

        Rows are split into chunks of positional indices, each chunk is handled
        by one thread-pool task, and the results are concatenated in the row
        order of `instances`. Every produced row carries the index label of the
        instance it was generated for.

        @param instances: DataFrame of rows to explain; `column_name` is dropped if present.
        @param neg_value: Forwarded to `generate_for_instance`.
        @param column_name: Name of the target column.
        @param n_jobs: Passed to ThreadPoolExecutor as `max_workers`.
        @param chunk_size: Rows per submitted task. When None or < 1, the rows
            are split evenly across `n_jobs` tasks.
        @param eps: When not None, assigned to `task.eps` (and `task.delta` for "RNCE").
        @param method_name: Method name; only "RNCE" triggers extra task setup.
        @return: Concatenated counterfactuals, or an empty DataFrame if there are none.
        """
        from concurrent.futures import ThreadPoolExecutor

        X_df = instances.copy(deep=False)
        if column_name in X_df.columns:
            X_df = X_df.drop(columns=[column_name])

        # Only touch the shared task when a value was actually supplied, so a
        # missing eps does not clobber the task's existing settings with None.
        if eps is not None:
            self.task.eps = eps
            if method_name == "RNCE":
                self.task.delta = eps

        n_rows = len(X_df)
        if chunk_size is None or chunk_size < 1:
            workers = n_jobs if n_jobs and n_jobs > 0 else 1
            chunk_size = max(1, math.ceil(n_rows / workers))

        # Chunks hold positional indices; labels are looked up separately so
        # arbitrary (non-0..n-1) indexes work.
        chunks = [range(start, min(start + chunk_size, n_rows))
                  for start in range(0, n_rows, chunk_size)]
        labels = X_df.index

        def _thread_worker(positions):
            frames = []
            for pos in positions:
                cf = self.generate_for_instance(
                    X_df.iloc[pos],
                    neg_value=neg_value,
                    column_name=column_name,
                    **kwargs)
                cf.index = [labels[pos]] * len(cf)
                frames.append(cf)
            return frames

        # Results arrive in completion order; slot them back by chunk number
        # so the final concat follows the original row order.
        per_chunk = [[] for _ in chunks]
        with ThreadPoolExecutor(max_workers=n_jobs) as executor:
            futures = {executor.submit(_thread_worker, chunk): chunk_no
                       for chunk_no, chunk in enumerate(chunks)}

            for future in tqdm(as_completed(futures), total=len(futures),
                            desc="CF-threads"):
                per_chunk[futures[future]] = future.result()

        ordered = [cf for frames in per_chunk for cf in frames]
        if not ordered:
            return pd.DataFrame()

        return pd.concat(ordered)

    def _get_task_init_args(self):
        """
        Extract initialization arguments from the task.
        Override this method in subclasses if needed.

        This is a best-effort placeholder: it returns `(model, training_data)`
        when the task exposes both attributes, `(model,)` when it only exposes
        `model`, and `()` otherwise or when reading the attributes fails.

        @return: Tuple of positional arguments for re-creating the task.
        """
        try:
            if hasattr(self.task, 'model') and hasattr(self.task, 'training_data'):
                return (self.task.model, self.task.training_data)
            elif hasattr(self.task, 'model'):
                return (self.task.model,)
            else:
                # Fallback: return empty args and let subclass handle it
                return ()
        except Exception:
            # Attribute access raised something other than AttributeError;
            # KeyboardInterrupt/SystemExit are intentionally not swallowed.
            return ()

    def generate_for_instance(self, instance, neg_value=0,
                              column_name="target", **kwargs) -> pd.DataFrame:
        """
        Generates a counterfactual for a provided instance.

        @param instance: The instance for which you would like to generate a counterfactual.
        @param column_name: The name of the target column.
        @param neg_value: The value considered negative in the target variable.
        @return: Whatever `_generation_method` returns for the instance
                 (a DataFrame of counterfactual explanations by contract).
        """
        return self._generation_method(instance, neg_value=neg_value, column_name=column_name, **kwargs)

    def generate_for_all(self, neg_value=0, column_name="target", **kwargs) -> pd.DataFrame:
        """
        Generates counterfactuals for all instances with a given negative value in their target column.

        `eps` defaults to the task's `eps` attribute (None if absent) and
        `method_name` defaults to this generator's class name, unless supplied.

        @param neg_value: The value in the target column which counts as a negative instance.
        @param column_name: The name of the target variable.
        @return: A DataFrame of the counterfactuals, re-indexed with the index
                 of the negative instances.
        """
        negatives = self.task.get_negative_instances(neg_value, column_name=column_name)

        # Only fill in eps and method_name if the caller did not pass them
        if 'eps' not in kwargs:
            kwargs['eps'] = getattr(self.task, 'eps', None)
        if 'method_name' not in kwargs:
            kwargs['method_name'] = self.__class__.__name__

        counterfactuals = self.generate(
            negatives,
            column_name=column_name,
            neg_value=neg_value,
            **kwargs
        )

        # Requires one result row per negative instance; pandas raises on a
        # length mismatch.
        counterfactuals.index = negatives.index
        return counterfactuals

    @abstractmethod
    def _generation_method(self, instance,
                           column_name="target", neg_value=0, **kwargs):
        """
        Abstract method to be implemented by subclasses for generating counterfactuals.

        @param instance: The instance for which to generate a counterfactual.
        @param column_name: The name of the target column.
        @param neg_value: The value considered negative in the target variable.
        @return: A DataFrame containing the generated counterfactuals.
        """
        pass

    @property
    def custom_distance_func(self):
        """
        Returns custom distance function passed at instantiation
        @return: The function given as `custom_distance_func`, or None.
        """
        return self.__customFunc